import ast
import os
import re
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from inference.constants import ITERATIVE_DICT_TEMPLATE, SCORE_SUM_COLUMN, PROMPT_COLUMN, ORIGINAL_COLUMN, IMAGE_COLUMN, ANNOTATION_COLUMNS, ANNOTATION_SUB_FOLDER
from inference.timer_collection import TimerCollection
from inference.utils import save_annotations_as_html

class PromptOptimizer:
    def __init__(self, scorer, generator, dataset, accelerator, timer_collection=None):
        """Coordinate prompt scoring / optimization over a dataset.

        Args:
            scorer: scoring backend (exposes get_all_scores, all_metrics,
                output_data_path, ...).
            generator: paraphrase generator; its mode flags (is_scoring,
                is_instant, is_iterative, is_iterative_for_annotation) select
                the processing workflow.
            dataset: dataset wrapper exposing rows and row-metadata accessors.
            accelerator: optional accelerate object for multi-process runs;
                None means sequential execution.
            timer_collection: optional TimerCollection; a fresh one is created
                when omitted.
        """
        self.scorer = scorer
        self.generator = generator
        self.dataset = dataset
        self.accelerator = accelerator
        if timer_collection is None:
            timer_collection = TimerCollection()
        self.timer_collection = timer_collection

    def save_results(self, data_dicts, current_idx, batch_size):
        if self.accelerator is not None:
            local_process_index = self.accelerator.local_process_index
        else:
            local_process_index = 0
        # Save the results
        prev_idx = current_idx
        current_idx = current_idx + batch_size
        old_filename = f"{self.scorer.output_data_path}{self.dataset.name}_process{local_process_index}_nextline_{str(prev_idx)}.csv"
        new_filename = f"{self.scorer.output_data_path}{self.dataset.name}_process{local_process_index}_nextline_{str(current_idx)}.csv"
        if os.path.isfile(old_filename):
            data_dicts = pd.read_csv(old_filename).to_dict(orient="records") + data_dicts
            os.remove(old_filename)
        pd.DataFrame(data_dicts).to_csv(new_filename, index=False)
        return current_idx
    
    def process_single_row(self, row, num_iter, process_number):
        """Score or optimize one dataset row according to the generator's mode.

        Modes (flags on ``self.generator``):
          * ``is_scoring``  -- score a fixed prompt (a configured column of the
            row, or the original prompt).
          * ``is_instant``  -- generate one best paraphrase, then score it.
          * ``is_iterative`` / ``is_iterative_for_annotation`` -- run up to
            ``num_iter`` rounds of paraphrase generation and scoring, keeping
            the best-scoring paraphrase seen so far. The annotation variant
            additionally appends every scored paraphrase (plus the current
            prompt) to a per-process annotation CSV.

        Args:
            row: one dataset row (mapping-like, as stored in dataset.rows).
            num_iter: iteration budget for the iterative modes.
            process_number: process index used for timer labels and scorer calls.

        Returns:
            dict: ``original_prompt`` plus either flat ``<metric>_score``
            columns (scoring/instant modes) or per-iteration score lists and
            ``best_iter`` (iterative modes), and ``best_paraphrase`` for every
            non-scoring mode.
        """
        with self.timer_collection.get_timer(f"process{process_number}_one_data_row"):
            ### FOR EACH PROMPT ###
            original_prompt, questions, choices, answers, noun_chunks = self.dataset.get_info_from_row(row)

            new_data_row = {
                "original_prompt": original_prompt,
            }

            if self.generator.is_scoring or self.generator.is_instant:
                if self.generator.is_scoring:
                    # Get prompt to score: a dedicated column when configured,
                    # otherwise the original prompt itself.
                    if len(self.dataset.scoring_column) > 0:
                        prompt_to_score = str(row[self.dataset.scoring_column]).strip()
                    else:
                        prompt_to_score = original_prompt

                else: # self.generator.is_instant:
                    with self.timer_collection.get_timer(f"process{process_number}_get_best_of_or_all_paraphrases"):
                        # Use GPT-4 to select the data_args.num_paraphrases generated by open-source LLM
                        best_paraphrase = self.generator.get_best_of_or_all_paraphrases(original_prompt)
                    prompt_to_score = best_paraphrase

                with self.timer_collection.get_timer(f"process{process_number}_compute_scores_for_instant_or_scoring_mode"):
                    # Compute & save scores as flat "<metric>_score" columns
                    scores_dict = self.scorer.get_all_scores(prompt_to_score, questions, choices, answers, noun_chunks, process_number)
                    for key in scores_dict:
                        new_data_row[f"{key}_score"] = scores_dict[key]

            elif self.generator.is_iterative or self.generator.is_iterative_for_annotation:
                # best_paraphrase stays "" when num_iter == 0 (no iterations run).
                best_paraphrase = ""
                current_prompt = original_prompt
                best_paraphrase_score = 0.0

                ### UP TO num_iter ITERATIONS ###
                # One list per metric (plus "paraphrase") holding per-iteration bests.
                for new_key in self.scorer.all_metrics_and_paraphrase:
                    new_data_row[ITERATIVE_DICT_TEMPLATE.format(new_key)] = []

                best_iter = -1

                if self.generator.is_iterative_for_annotation:
                    annotation_datapoints = dict((column, []) for column in self.scorer.all_metrics + ANNOTATION_COLUMNS)

                previous_prompts = []

                for iter_idx in range(num_iter):
                    with self.timer_collection.get_timer(f"process{process_number}_one_optimization_iteration"):
                        with self.timer_collection.get_timer(f"process{process_number}_get_best_of_or_all_paraphrases"):
                            ### Use GPT-4 to select the data_args.num_paraphrases paraphrases generated by the open-source LLM
                            paraphrases = self.generator.get_best_of_or_all_paraphrases(original_prompt, iter_idx, questions, choices, answers, noun_chunks, previous_prompts)

                        ## For each paraphrase
                        all_scores_dict = dict((metric, []) for metric in self.scorer.all_metrics)

                        ## Adding empty list for is_iterative_for_annotation mode to produce file
                        if self.generator.is_iterative_for_annotation:
                            for column in ANNOTATION_COLUMNS:
                                all_scores_dict[column] = []

                        with self.timer_collection.get_timer(f"process{process_number}_computing_scores_for_all_paraphrases"):
                            ## Add scores for original prompt for is_iterative_for_annotation mode
                            if self.generator.is_iterative_for_annotation:
                                paraphrases = [current_prompt] + paraphrases

                            ## Compute scores
                            for p_idx, paraphrase in enumerate(paraphrases):
                                scores_dict = self.scorer.get_all_scores(paraphrase, questions, choices, answers, noun_chunks, process_number, save_for_annotation=self.generator.is_iterative_for_annotation, top_logprobs=5)

                                for key in scores_dict:
                                    if key in all_scores_dict:
                                        all_scores_dict[key].append(scores_dict[key])

                                # Record paraphrase history for the generator's next
                                # round; the prepended current prompt (annotation
                                # mode, p_idx == 0) is excluded.
                                # NOTE(review): assumes scores_dict always carries
                                # "tifa"/"vqa"/"vqa_scores"/"answers" keys.
                                if not (self.generator.is_iterative_for_annotation and p_idx == 0):
                                    this_entry = {
                                        "revised_prompt": paraphrase,
                                        "tifa": scores_dict["tifa"],
                                        "vqa": scores_dict["vqa"],
                                        "vqa_scores": scores_dict["vqa_scores"],
                                        "score": (scores_dict["tifa"] + scores_dict["vqa"])/2,
                                        "answers": scores_dict["answers"],
                                    }

                                    previous_prompts.append(this_entry)

                                ## Each row should save original prompt + paraphrase used or original prompt if that was used
                                if self.generator.is_iterative_for_annotation:
                                    all_scores_dict[PROMPT_COLUMN].append(paraphrase)
                                    all_scores_dict[ORIGINAL_COLUMN].append(current_prompt)

                        ## Restore all_scores_dict and paraphrases
                        if self.generator.is_iterative_for_annotation:
                            for key in all_scores_dict:
                                annotation_datapoints[key] = annotation_datapoints[key] + all_scores_dict[key]
                                all_scores_dict[key] = all_scores_dict[key][1:]
                            # BUG FIX: also drop the prepended current prompt so the
                            # argmax over the trimmed score lists below indexes the
                            # matching paraphrase (it was previously shifted by one).
                            paraphrases = paraphrases[1:]

                        for metric in self.scorer.all_metrics:
                            new_data_row[ITERATIVE_DICT_TEMPLATE.format(metric)].append(np.max(all_scores_dict[metric]))

                        # Greedily move to the paraphrase with the best combined score.
                        current_prompt = paraphrases[np.argmax(all_scores_dict[SCORE_SUM_COLUMN])]
                        new_data_row[ITERATIVE_DICT_TEMPLATE.format("paraphrase")].append(current_prompt)

                        # ">=" keeps the most recent paraphrase on score ties.
                        if np.max(all_scores_dict[SCORE_SUM_COLUMN]) >= best_paraphrase_score:
                            best_paraphrase = current_prompt
                            best_paraphrase_score = np.max(all_scores_dict[SCORE_SUM_COLUMN])
                            best_iter = iter_idx

                new_data_row["best_iter"] = best_iter

                if self.generator.is_iterative_for_annotation:
                    # Append this row's annotation datapoints to the per-process CSV.
                    existing_df, annot_file_path = self.get_existing_annotations(process_number)

                    # Step 2: Create a new DataFrame with the rows you want to add
                    new_df = pd.DataFrame(annotation_datapoints)

                    # Step 3: Append the new DataFrame to the existing one
                    if not existing_df.empty:
                        combined_df = pd.concat([existing_df, new_df], ignore_index=True)
                    else:
                        combined_df = new_df

                    # Step 4: Write the combined DataFrame back to the CSV file
                    combined_df.to_csv(annot_file_path, index=False)

            if self.generator.is_iterative or self.generator.is_iterative_for_annotation or self.generator.is_instant:
                new_data_row["best_paraphrase"] = best_paraphrase
        return new_data_row

    def get_existing_annotations(self, process_number):
        """Load this process's annotation CSV, if one has been written yet.

        Args:
            process_number: process index embedded in the annotation filename.

        Returns:
            tuple: (DataFrame, path) — the DataFrame is empty when no
            annotation file exists for this process.
        """
        annot_file_path = f"{self.scorer.output_data_path}{ANNOTATION_SUB_FOLDER}_process{process_number}.csv"
        # Guard clause: nothing written yet, hand back an empty frame.
        if not os.path.exists(annot_file_path):
            return pd.DataFrame(), annot_file_path
        return pd.read_csv(annot_file_path), annot_file_path

    def compile_results(self, get_scores=True, return_iterated_set=False):
        # Initialize a list to store dataframes
        dfs = []
        base_dir_files = os.listdir(self.scorer.output_data_path)
        these_files = [f for f in base_dir_files if f.startswith(self.dataset.name)]
        
        # Dictionary to store the highest number file for each process
        process_files = defaultdict(lambda: {"number": -1, "filename": ""})

        # Regex to extract process number and number
        pattern = re.compile(r"process(\d+)_nextline_(\d+)")

        for filename in these_files:
            match = pattern.search(filename)
            if match:
                process_num = int(match.group(1))
                number = int(match.group(2))
                if number > process_files[process_num]["number"]:
                    process_files[process_num] = {"number": number, "filename": filename}
                    old_filename = process_files[process_num]["filename"]
                    if os.path.exists(old_filename):
                        os.remove(old_filename)

        # Extract the final list of filenames with the highest number for each process
        final_filenames = [os.path.join(self.scorer.output_data_path, info["filename"]) for info in process_files.values() if len(info["filename"]) > 0]
        combined_file_name = f"{self.scorer.output_data_path}{self.dataset.name}_combined.csv"
        for file_name in [combined_file_name] + final_filenames:
            if os.path.exists(file_name):
                try:
                    df = pd.read_csv(file_name)
                    dfs.append(df)
                except pd.errors.EmptyDataError:
                    print("The file is empty!")
                    # Handle the case, e.g., by creating an empty DataFrame
                os.remove(file_name)

        # If there are dataframes, concatenate them and calculate the averages
        if dfs:
            # Combine all dataframes
            combined_df = pd.concat(dfs)
            combined_df = combined_df.drop_duplicates(subset=[ORIGINAL_COLUMN])
            combined_df.to_csv(combined_file_name, index=False)

            if get_scores:
                # Computing Average Scores
                if self.generator.is_iterative or self.generator.is_iterative_for_annotation:
                    metric_keys = [ITERATIVE_DICT_TEMPLATE.format(metric) for metric in self.scorer.all_metrics]
                else: # if self.generator.is_scoring or self.generator.is_instant:
                    metric_keys = [f"{metric}_score" for metric in self.scorer.all_metrics]
                avg_scores = combined_df[metric_keys]
                if self.generator.is_iterative or self.generator.is_iterative_for_annotation:
                    for column in metric_keys:
                        avg_scores[column] = avg_scores[column].apply(eval).apply(lambda x: max(x))
                avg_scores = avg_scores.mean()
                avg_scores = (avg_scores * 100).round(2)  # Convert to percentages and round to 2 decimals
                avg_scores.to_csv(f"{self.scorer.output_data_path}{self.dataset.name}_average_scores.csv", index=True)

            if return_iterated_set:
                return list(combined_df["original_prompt"])
        elif return_iterated_set:
            return []
    
    def iterate_over_indices(self, batch_size, num_iter, indices, message, process_number):
        """Process the given dataset indices, checkpointing every batch_size rows.

        Rows whose prompt already appears in previously compiled results are
        skipped, so an interrupted run resumes where it left off.

        Args:
            batch_size: number of newly processed rows between checkpoint saves.
            num_iter: iteration budget forwarded to process_single_row.
            indices: dataset row indices this process is responsible for.
            message: tqdm progress-bar description.
            process_number: process index for file naming and timer labels.
        """
        # Prompts already present in earlier output files (resume support).
        iterated_set = self.compile_results(get_scores=False, return_iterated_set=True)
        data_dicts = []
        count = 0
        idx = None  # retains the last visited index for the final checkpoint
        for idx in tqdm(indices, desc=message):
            row = self.dataset.rows[idx][1]
            if row["prompt"].strip() not in iterated_set:
                new_data_row = self.process_single_row(row, num_iter, process_number)
                data_dicts.append(new_data_row)
                count += 1
                if count % batch_size == 0:
                    self.save_results(data_dicts, idx, batch_size)
                    data_dicts = []

        # Flush any rows left over from a partial final batch. Guarding on
        # data_dicts (instead of count % batch_size) is equivalent here but
        # also safe when `indices` is empty (idx would otherwise be unbound).
        if data_dicts:
            self.save_results(data_dicts, idx, batch_size)
            data_dicts = []

        if self.generator.is_iterative_for_annotation:
            # Save Annotations File as HTML
            existing_df, _ = self.get_existing_annotations(process_number)
            save_annotations_as_html(existing_df, self.scorer.output_data_path)

    def iterate_distributed(self, batch_size, num_iter):
        """Run the pipeline across accelerator processes; rank 0 then compiles
        the combined results and saves timer statistics.

        Args:
            batch_size: rows between checkpoint saves (per process).
            num_iter: iteration budget forwarded to process_single_row.
        """
        # Raise error if invalid mode
        if not (self.generator.is_scoring or self.generator.is_instant or self.generator.is_iterative or self.generator.is_iterative_for_annotation):
            raise ValueError(f"Invalid mode: {self.generator.mode}")
        
        # Each process receives its own slice of the dataset index range.
        with self.accelerator.split_between_processes(list(range(self.dataset.starting_point, self.dataset.ending_point))) as these_indices:
            self.iterate_over_indices(batch_size, num_iter, these_indices, f"process{self.accelerator.local_process_index}", self.accelerator.local_process_index)

        # Logic for the particular process to wait until all others are done
        if self.accelerator.local_process_index == 0:
            # Synchronize all processes
            # NOTE(review): only rank 0 enters this barrier; the other ranks
            # never call torch.distributed.barrier(), so this relies on the
            # barrier erroring/timing out (caught below) rather than a true
            # synchronization — confirm whether
            # accelerator.wait_for_everyone() was intended instead.
            if torch.distributed.is_initialized():
                try:
                    torch.distributed.barrier()
                    print("Master process waited and all processes synchronized.")
                except Exception as e:
                    print(f"Master process timeout or error: {e}")
            # This block of code will only be executed by the main process (index 0)
            # Any logic that you want to execute after all processes are done goes here
            print("All processes have reached the barrier.")
            self.compile_results()
            self.timer_collection.save_timer_results_as_csv(f"{self.scorer.output_data_path}{self.dataset.name}_timer_results.csv")

    def iterate_non_distributed(self, batch_size, num_iter):
        """Run the whole pipeline sequentially on a single process, then
        compile results and save timer statistics.

        Raises:
            ValueError: when the generator is in none of the supported modes.
        """
        valid_mode = (
            self.generator.is_scoring
            or self.generator.is_instant
            or self.generator.is_iterative
            or self.generator.is_iterative_for_annotation
        )
        if not valid_mode:
            raise ValueError(f"Invalid mode: {self.generator.mode}")

        indices = list(range(self.dataset.starting_point, self.dataset.ending_point))
        self.iterate_over_indices(batch_size, num_iter, indices, "Sequential processing", 0)
        self.compile_results()
        self.timer_collection.save_timer_results_as_csv(f"{self.scorer.output_data_path}{self.dataset.name}_timer_results.csv")

    def run(self, batch_size, num_iter):
        """Entry point: dispatch to distributed or sequential execution
        depending on whether an accelerator was supplied."""
        runner = (
            self.iterate_non_distributed
            if self.accelerator is None
            else self.iterate_distributed
        )
        runner(batch_size, num_iter)
